#!/usr/bin/python3 -u 

__author__ = "Ibrahim Abu Kharmeh"
__copyright__ = "Copyright 2021, Huawei R&D Bristol"
__license__ = "BSD 2-Clause"
__version__ = "0.4.0"

from elf import elf
import os
import sys
from argparse import ArgumentParser
from collections import Counter
from glob import glob
import utils
import itertools
import re
from time import time


# Command line interface of the script: the input ELF file(s), the analysis
# actions to run on them, and optional output / verbosity / section settings.
parser = ArgumentParser()
parser.add_argument("-f", "--file", dest="filename",
                    help="Input ELF file", metavar="FILE",nargs="+")
parser.add_argument("-a", "--actions", dest="actions",
                    help="Analysis actions",nargs="+")
parser.add_argument("-r", "--resultsdir", dest="resultsdir",
                    help="Results directory")
parser.add_argument("-v", "--verbosity", dest='verbosity',
                    help="Change default verbosity level",type=int,default=0)
parser.add_argument("-j", "--sections", dest='sections',
                    help="Specify sections to analyse")


# Get the current working directory from PWD, which is helpful when the
# script is called through a symbolic link. PWD is not exported in every
# environment, so fall back to os.getcwd() instead of failing with KeyError
# at import time.
pwd = os.environ.get("PWD") or os.getcwd()


def main():
    """Run the requested analysis actions on every input ELF file.

    Command line arguments are read through the module-level ``parser``.
    Each entry of ``--file`` is expanded with ``glob`` (first relative to
    ``pwd``, then as given). For every file an ``elf`` object is built and
    each ``--actions`` entry is applied to it in order. An action has the
    form ``name[,option[,option...]]``: the text before the first comma
    selects the action, the remaining comma separated fields are its options.

    Some actions print their result directly; others append to
    ``action_responses``, which is formatted at the end and written either to
    stdout or, when ``--resultsdir`` is given, to
    ``<resultsdir>/<actions joined by '-'>.csv``.

    Returns:
        The error counter (currently never incremented, so always 0).

    Exits:
        ``sys.exit(-1)`` when no file or no action was supplied, and
        ``sys.exit(-6)`` when an action is missing its mandatory options.
    """
    args = parser.parse_args()
    global verbosity
    start_time = time()
    # Check that we have enough args
    if (args.filename is None or args.actions is None):
        print("Not enough arguments used")
        parser.print_help(sys.stderr)
        sys.exit(-1)

    # Parse the input file names: each entry may be a glob pattern, tried
    # relative to pwd first and then exactly as given.
    inputfiles = list()
    for elffile in args.filename:
        matches = glob(pwd+"/"+elffile) or glob(elffile)
        if (matches):
            inputfiles.extend(matches)
        else:
            print(elffile + " was not found")

    action_responses = list()
    errors = 0
    verbosity =  args.verbosity

    # NOTE(review): filled by the "instchain" action but never reported below - confirm whether output is missing.
    analysed_instructions = {}

    for elfname in inputfiles:
        program = elf(elfname,pwd,verbosity,args.sections)
        if (program.init == False):
            continue
        basename = elfname.split("/")[-1]
        program.dprint("Processing", basename,colour="yellow")
        program.construct_main_dict()
        original_dict_size = program.calculate_dict_size()

        for action in args.actions:

            # "name,opt1,opt2" -> action "name", options ["opt1", "opt2"].
            # options is always a list (possibly empty), never None.
            options = action.split(",")[1:]
            action = action.split(",")[0]

            if (action == "size"):
                if (len(options) == 0):
                    # size = program.calculate_exec_size()
                    current_dict_size = program.calculate_dict_size() + program.rodata
                    saving_percentage = str("{:.2f}".format((1-(current_dict_size/original_dict_size))*100)) if (original_dict_size > 0) else "0.00"

                    # The following prints executable sections sizes, the dictionary size and the optimized dictionary size
                    if (args.resultsdir is None):
                        print(basename+ ",",original_dict_size,",",current_dict_size, ",", saving_percentage+"%")
                    else:
                        action_responses.append((basename.split(".")[0] , str(original_dict_size), saving_percentage+"%"))
                else:
                    # Prints the size of each function in the dict, useful to pinpoint which functions contribute the most !
                    lsize = program.calculate_dict_size(mode=1)
                    if (args.resultsdir is None):
                        for file in lsize:
                            print(file,lsize[file])
                    else:
                        # NOTE(review): extending with lsize adds only what iterating it yields (the keys, if it is a dict),
                        # which the "size" formatter below does not appear to expect - confirm intended output.
                        action_responses.extend(lsize)

            elif (action == "dict"):
                program.print_dict()

            elif (action == "peek"):
                # An empty option list means no fields were requested
                if (not options):
                    print("Peek needs a list of field passed in options")
                    sys.exit(-6)
                program.print_options(options)

            # "Immediate Coverage, Mainly used to evaluate IMM length benefit"
            elif (action == "coverage"):
                if (len(options)<2):
                    print("Coverage needs instruction name and signedness passed in options")
                    sys.exit(-6)

                offsets = program.retrieve_field(req_instructions = [options[0]],options=["Immediate"])
                offsets_frequency = Counter([int(offset["Immediate"]) for offset in offsets ])

                coverage = utils.immediate_coverage(offsets_frequency,options[1],12)

                action_responses.append(coverage)
            elif (action in ["c_sb","c_sh","c_lhu","c_lbu","c_lh","c_lb"]):
                # action[2:] strips the "c_" prefix to get the base load/store mnemonic
                offsets = program.retrieve_field(req_instructions = [action[2:]],additional_condition= lambda instruction : utils.fit_in_field(int(instruction["Immediate"]),"unsigned",4)
                                                and utils.reg_within(instruction['Destination'][0])  and utils.reg_within(instruction['Source'][0]) )
                program.apply_optimization(old_instructions = offsets,optimization={"Instruction":"c."+action.replace("_","."),'WoE':16},mode=2)

            elif (action == "sp"):
                # The instruction to analyse is mandatory; without this check options[0] raises IndexError
                if (not options):
                    print("sp needs the instruction name (lbu, lhu, sb or sh) passed in options")
                    sys.exit(-6)

                lb_sb_conditions = (lambda instruction: instruction["Source"] == ('sp',) and utils.reg_within(instruction['Destination'][0]) and int(instruction["Immediate"]) <= 2**5 )
                lhu_shu_conditions = (lambda instruction: instruction["Source"] == ('sp',) and utils.reg_within(instruction['Destination'][0]) and int(instruction["Immediate"]) <= 2**6  and int(instruction["Immediate"])%2==0 )

                if(options[0] == "lbu"):
                    offsets = program.retrieve_field(req_instructions=["lbu"],additional_condition=lb_sb_conditions)
                    program.apply_optimization(old_instructions=offsets,optimization={"Instruction":"c.lbusp",'WoE':16},mode=2)
                elif(options[0] == "lhu"):
                    offsets = program.retrieve_field(req_instructions=["lhu"],additional_condition=lhu_shu_conditions)
                    program.apply_optimization(old_instructions=offsets,optimization={"Instruction":"c.lhusp",'WoE':16},mode=2)
                elif(options[0] == "sb"):
                    offsets = program.retrieve_field(req_instructions=["sb"], additional_condition=lb_sb_conditions)
                    program.apply_optimization(old_instructions=offsets,optimization={"Instruction":"c.sbsp",'WoE':16},mode=2)
                elif(options[0] == "sh"):
                    offsets = program.retrieve_field(req_instructions=["sh"], additional_condition=lhu_shu_conditions)
                    program.apply_optimization(old_instructions=offsets,optimization={"Instruction":"c.shsp",'WoE':16},mode=2)
                else:
                    # Report the unknown sub-option instead of silently doing nothing
                    print("sp option:",options[0],"not supported")
            elif (action == "c_not"):
                NOTS = program.retrieve_field(req_instructions=["not"],additional_condition=lambda instruction: utils.reg_within(instruction['Destination'][0]) and instruction['Destination'][0] == instruction["Source"][0])
                program.apply_optimization(old_instructions=NOTS,optimization={"Instruction":"c.not",'WoE':16},mode=2)
            elif (action == "c_neg"):
                NEG = program.retrieve_field(req_instructions=["neg"], additional_condition=lambda instruction: utils.reg_within(instruction['Destination'][0]) and
                                            instruction['Destination'][0] == instruction["Source"][0])
                program.apply_optimization(old_instructions=NEG,optimization={"Instruction":"c.neg",'WoE':16},mode=2)
            elif (action == "c_mul"):
                mult = program.retrieve_field(req_instructions=["mul"], additional_condition = (lambda instruction: instruction['Destination'][0] in instruction["Source"]
                    and utils.reg_within(instruction["Source"][0]) and utils.reg_within(instruction["Source"][1])))
                program.apply_optimization(old_instructions=mult,optimization={"Instruction":"c.mul",'WoE':16},mode=2)
            elif (action == "DECBGEZ"):
                # Count branch-on-zero style branches whose source chain contains an andi / addi
                beqnz = program.find_source_dependencies(mode=0,target_inst=["bnez"],woe_insensitive=True,traverseback=True)
                beqnz_stat = utils.filter_instruction_chains(beqnz,["andi"],woe_insensitive=True)

                beqz = program.find_source_dependencies(mode=0,target_inst=["beqz"],woe_insensitive=True,traverseback=True)
                beqz_stat = utils.filter_instruction_chains(beqz,["andi"],woe_insensitive=True)

                bgtz = program.find_source_dependencies(mode=0,target_inst=["bgtz"],woe_insensitive=True,traverseback=True)
                bgtz_stat = utils.filter_instruction_chains(bgtz,["addi"],woe_insensitive=True)

                bltz = program.find_source_dependencies(mode=0,target_inst=["bltz"],woe_insensitive=True,traverseback=True)
                bltz_stat = utils.filter_instruction_chains(bltz,["addi"],woe_insensitive=True)

                print(",",program.calculate_dict_size(),",",len(beqnz_stat),",",len(beqz_stat),",",len(bgtz_stat),",",len(bltz_stat))

            elif (action in ["c_sext_b","c_sext_h"]):
                # "C_SEXT_W" is simply a pseudoinstruction for C.ADDIW, so we don't need to analyse for it !
                # A sign extension is matched as slli followed by srai with the same shift amount
                shift_range = {"c_sext_b":"0x18","c_sext_h":"0x10"}
                inst = action.replace("_",".")
                sext = program.find_dependant_insts(mode=0,target_inst="slli",woe_insensitive=True,append_target_inst=True,
                                                    start_condition = (lambda instruction : (instruction["Immediate"] == shift_range[action])))
                filtered_sext = utils.filter_instruction_chains(sext,["srai"],woe_insensitive=True,
                                                                additional_condition= (lambda instruction: instruction["Immediate"] == shift_range[action]  and utils.reg_within(instruction['Destination'][0])))
                program.apply_optimization(old_instructions=filtered_sext,optimization={"Instruction":inst,'WoE':16},mode=0)
            elif (action in ["c_zext_b","c_zext_h","c_zext_w"]):
                if (action == "c_zext_w" and program.xlen == 32):
                    # We won't find any in a 32bit file, so there is no point of looking !
                    continue
                elif (action == "c_zext_b"):
                    # If it's zero extend byte, then it would be andi with 255 (this does not fit in the compressed range, that's why we need an instruction for it) !
                    zext = program.retrieve_field(req_instructions = ["andi"],additional_condition= (lambda instruction:  (instruction['Destination'][0] == instruction["Source"][0])
                                                    and instruction["Immediate"] == "255"  and utils.reg_within(instruction['Destination'][0]) ))
                    program.apply_optimization(old_instructions=zext,optimization={"Instruction":"c.zext.b",'WoE':16},mode=2)
                else:
                    # Otherwise it would be two logical shifts !!
                    shift_range = {"c_zext_h":"0x10","c_zext_w":"0x20"}
                    zext = program.find_dependant_insts(mode=0,target_inst="slli",woe_insensitive=True,append_target_inst=True,
                                                        start_condition = (lambda instruction : (instruction["Immediate"] == shift_range[action])))
                    filtered_zext = utils.filter_instruction_chains(zext,["srli"],woe_insensitive=True,
                                                        additional_condition= (lambda instruction: instruction["Immediate"] == shift_range[action]  and utils.reg_within(instruction['Destination'][0])))
                    program.apply_optimization(old_instructions=filtered_zext,optimization={"Instruction":action.replace("_","."),'WoE':16},mode=0)

            elif (action in ["beqi","bnei","bgei","bgeui","blti","bltui"]):
                # cmpimm[4:0] == utils.fit_in_field(int(instruction["Immediate"]),"unsigned",5)
                # Offset range is same as normal SB class instructions, thus we don't need to check for it
                # action[:-1] drops the trailing "i" to get the register-register branch mnemonic
                load_immediate = program.find_source_dependencies(mode=0,target_inst=[action[:-1]],woe_insensitive=True)
                branch_loads = utils.filter_instruction_chains(load_immediate,["li"],woe_insensitive=True,additional_condition= (lambda instruction: utils.fit_in_field(int(instruction["Immediate"]),"unsigned",5)))
                program.apply_optimization(old_instructions=branch_loads,optimization={"Instruction":action,'WoE':32},mode=0)
            elif (action == "pushpop"):
                # Any option at all switches find_push_pop to its "compressed" variant
                compressed = len(options) > 0
                (push,pop) = program.find_push_pop(compressed)
                for adj in push:
                    if(adj["rcount_val"] != 0):
                        program.apply_optimization([adj["Adj"]]+adj["Opt"]+adj["Embedded_moves"], {
                                                   "Instruction": adj["Instruction"], 'WoE': adj['WoE'], "Source": adj["rcount_val"]}, 1)
                for adj in pop:
                    if(adj["rcount_val"] != 0):
                        # Collect the instructions to be replaced by the pop
                        pop_tbr = []
                        pop_tbr += [adj["Adj"]]
                        pop_tbr += adj["Opt"]
                        if ("Ret" in adj): pop_tbr += [adj["Ret"]]
                        if ("ret_val" in adj): pop_tbr += [adj["ret_val"]]
                        program.apply_optimization(pop_tbr, {
                                                    "Instruction": adj["Instruction"], 'WoE': adj['WoE'], 'Destination': adj["rcount_val"]}, 1)

            # Instruction frequency, used to check the frequency of instruction occurrences throughout a given file, or within a given instruction category !
            elif (action == "instfreq"):
                req_category = options[0] if (len(options) > 0) else None
                action_responses.append(Counter(program.retrieve_insts(req_cat=req_category)))
            # Report the frequency the program needed to construct long address relative to its sections !
            # Used to have a general idea which sections need to be targeted for GP as sometimes it's not obvious if sections have unusual names other than SBSS etc
            elif (action == "longaddress"):
                program.read_elf_sections()
                lw_sw_long = program.retrieve_field(req_instructions = ["lw","sw"],options=["Address"],exact=True)
                section_counter = {section["Section"]:0 for section in program.sections}
                section_counter["unknown"] = 0
                for inst in lw_sw_long:
                    address = int(inst["Address"],16)
                    identified = False
                    for section in program.sections:
                        # A section covers [Address, Address+Size): the end address is exclusive,
                        # otherwise the first byte after a section is counted in it.
                        if (section["Address"] <= address < section["Address"]+section["Size"]):
                            section_counter[section["Section"]] +=1
                            identified = True
                            break
                    if not (identified):
                        section_counter["unknown"] +=1

                for x in sorted(section_counter.items(), key = (lambda x : x[1]),reverse=True):
                    if(x[1]>0): print (x[0],":",x[1])

            elif (action == "instchain"):
                if (not options):
                    # If no instruction is passed when analysing instruction pairs, then we retrieve all instructions in the file and we test against the 50 most common
                    Instructions_Frequency = Counter(program.retrieve_insts())
                    Instruction_target = Instructions_Frequency.most_common(50)
                    for Instruction in Instruction_target:
                        program.dprint("Instruction", Instruction[0] )
                        mnemonics = program.find_dependant_insts(mode=1,target_inst=Instruction[0])
                        if (Instruction[0] not in analysed_instructions):
                            analysed_instructions[Instruction[0]] =  utils.count_insts(mnemonics,1)
                        else:
                            analysed_instructions[Instruction[0]] =  sum((analysed_instructions[Instruction[0]],utils.count_insts(mnemonics,1)),Counter())

                else:
                    mnemonics = program.find_dependant_insts(mode=0,target_inst = options[0])
                    if (options[0] not in analysed_instructions):
                        analysed_instructions[options[0]] =  utils.count_insts(mnemonics,1)
                    else:
                        analysed_instructions[options[0]] =  sum((analysed_instructions[options[0]],utils.count_insts(mnemonics,1)),Counter())

            elif (action in ["c_tblj", "c_tbljal","c_tbljalm"]):

                # Per action: (link register the jump must write, maximum number of table entries)
                tbljal_boundary = {"c_tblj":("zero",56), "c_tbljal":("ra",192),"c_tbljalm":("t0",8)}

                AUIPC_DEP = program.find_dependant_insts(mode=0,target_inst="auipc",woe_insensitive=True,append_target_inst=True)
                JALR_JR = utils.filter_instruction_chains(AUIPC_DEP,["jalr","jr"],additional_condition= lambda instruction : "Target_address" in instruction and instruction["Destination"][0] == tbljal_boundary[action][0])
                JAL_J = program.retrieve_field(req_instructions=["jal","j"],additional_condition= lambda instruction : "Target_address" in instruction and instruction["Destination"][0]  == tbljal_boundary[action][0])

                jumps = [jmp[1] for jmp in JALR_JR] + JAL_J

                # Create a new list for eligible jumps without PC and other various fields so we can check frequency
                table_jump = [(jump["Instruction"],jump['Destination'],jump['Target_address']) for jump in jumps]

                # Calculate the minimum number of jumps to be feasible for replacement !
                # For when XLEN=32, each JAL J replacement cost us 32bits in RO data and saves us 16bits, thus we need it to be used at least 3 times for it to be beneficial
                #                   each JALR JR replacement cost us 32bits in RO data and saves us 48bits, thus we can replace it even if its used a single time

                # For when XLEN=64, each JAL J replacement cost us 64bits in RO data and saves us 16bits, thus we need it to be used at least 5 times for it to be beneficial
                #                   each JALR JR replacement cost us 64bits in RO data and saves us 48bits, thus we need it to be used at least 2 times for it to be beneficial

                jr_jalr_threashold = 0 if (program.xlen == 32) else 2
                j_jal_threashold = 3 if (program.xlen == 32) else 5
                def change_weights(tbl_freq):
                    # Merge the per-mnemonic counts into one weight per (Destination, Target_address);
                    # the first entry seen for a target is zeroed when below its threshold.
                    output = {}
                    for jmp in tbl_freq:
                        if (jmp[0] in ["jr","jalr"]):
                            if (jmp[1:] in output):
                                output[jmp[1:]] = output[jmp[1:]] + tbl_freq[jmp]
                            else:
                                output[jmp[1:]] =  0 if (tbl_freq[jmp] < jr_jalr_threashold) else tbl_freq[jmp]*3
                        elif (jmp[0] in ["jal", "j"]):
                            if (jmp[1:] in output):
                                output[jmp[1:]] = output[jmp[1:]] + tbl_freq[jmp]
                            else:
                                output[jmp[1:]] = 0 if (tbl_freq[jmp] < j_jal_threashold) else tbl_freq[jmp]
                    return(output)
                table_jump_frequency = change_weights(Counter(table_jump))

                # Keep the most heavily weighted targets, up to the table capacity of this action
                jmp_frequency_selection  = ([jmp[0] for jmp in Counter(table_jump_frequency).most_common(tbljal_boundary[action][1]) if jmp[1] > 0])
                selected_jalr_jr = list()
                selected_jal_j = list()
                for jump in JALR_JR:
                    if((jump[1]['Destination'],jump[1]['Target_address']) in jmp_frequency_selection):
                        selected_jalr_jr.append(jump)
                for jump in JAL_J:
                    if((jump['Destination'],jump['Target_address']) in jmp_frequency_selection):
                        selected_jal_j.append(jump)

                program.apply_optimization(selected_jalr_jr,{"Instruction":action.replace("_","."),'WoE' : 16},0)
                program.apply_optimization(selected_jal_j,{"Instruction":action.replace("_","."),'WoE' : 16},2)
                # Every selected target costs one XLEN-wide table entry in read only data
                program.add_ro_data((len(jmp_frequency_selection)* (program.xlen/8)))

            elif (action in ["c_mva01s07","c_mva23s07","c_mva01s03","c_mva23s03"]):
                (_,chains) = program.find_mv_chains(inst=action)
                program.apply_optimization(old_instructions=chains,optimization={"Instruction":action.replace("_","."),'WoE':16},mode=0)
            elif (action == "muli"):
                # MULI Only ! rd' = rs1' * sign_ext(imm)
                # Bitlen for imm is 12 which is the same as li/addi, no need to filter imm range start_condition = (lambda entry: utils.fit_in_field(int(entry["Immediate"]),"signed",12)
                # It uses complete register range, thus no need to filter for compressed register format !
                load_immediates = program.find_dependant_insts(mode=0, target_inst="li", woe_insensitive=True, append_target_inst=True)
                muli = utils.filter_instruction_chains(load_immediates, ["mul"],woe_insensitive=True)
                # Only keep chains made of exactly two instructions
                independant_muli = [x for x in muli if len(x) == 2]
                program.apply_optimization(old_instructions=independant_muli, optimization={"Instruction": "muli", 'WoE': 32}, mode=0)
            elif (action == "addiadd"):
                # ADDIADD rd' = rs1' + rs2' + sign_ext(imm)
                # Bitlen for imm is 9
                load_immediates = program.find_dependant_insts(mode=0, target_inst="addi", woe_insensitive=True, append_target_inst=True,
                                start_condition = (lambda entry: utils.fit_in_field(int(entry["Immediate"]),"signed",9)))
                addiadd = utils.filter_instruction_chains(load_immediates, ["add"])
                # Skip pairs where both instructions are already 16 bits wide
                independant_addiadd = [x for x in addiadd if len(x) == 2 and (x[0]['WoE'] != 16 or x[1]['WoE'] != 16 )]
                program.apply_optimization(old_instructions=independant_addiadd, optimization={
                                               "Instruction": "addiadd", 'WoE': 32}, mode=0)
            elif (action == "muliadd"):
                # MULIADD rd' = rs1' + rs2' * sign_ext(imm)
                load_immediates = program.find_dependant_insts(mode=0, target_inst="li", woe_insensitive=True, append_target_inst=True,
                                start_condition = (lambda entry: utils.fit_in_field(int(entry["Immediate"]),"signed",9)),
                                pass_through=(lambda entry: entry["Instruction"] == "mul" or utils.categorise_inst(entry["Instruction"]) == "alu_imm"))
                # Make sure that each chain has mul and add and that they contain only 3 entries, if len>3, it means that li value or mul value "intermediate result"
                # was used for another operation !
                instructions_required=set(["mul", "add"])
                muliadd = []
                for mnenomic in load_immediates:
                    # Mnemonics of the chain with any "c." prefix removed
                    instructions_found = [inst["Instruction"][2:] if ("c." in inst["Instruction"]) else inst["Instruction"] for inst in mnenomic]
                    if (len(instructions_required.difference(set(instructions_found))) != 0 and "mul" in instructions_found):
                        # The chain has a mul but no add yet: follow the mul to look for the add
                        starting_condition = mnenomic[instructions_found.index("mul")]
                        extended = program.find_dependant_insts_from(mode=0,starting_inst=starting_condition,pass_through = lambda entry : utils.categorise_inst(entry["Instruction"]) == "alu_imm")
                        instructions_found = [inst["Instruction"][2:] if ("c." in inst["Instruction"]) else inst["Instruction"] for inst in extended]
                        if ("add" in instructions_found):
                            mnenomic.extend(extended)
                            instructions_found = [inst["Instruction"][2:] if ("c." in inst["Instruction"]) else inst["Instruction"] for inst in mnenomic]
                        else:
                            continue

                    if(len(instructions_required.difference(set(instructions_found))) == 0):
                        if (len(mnenomic) == 3):
                            muliadd.append(mnenomic)
                            # Sometimes we have an immediate arithmetic operation where order won't matter since the intermediate values are not used !
                        elif (len(mnenomic) == 4 and len(set(instructions_found)) == 4):
                            muliadd.append([instruction for instruction in mnenomic if utils.categorise_inst(instruction["Instruction"]) != "alu_imm"])

                program.apply_optimization(old_instructions=muliadd, optimization={"Instruction": "MULIADD", 'WoE': 32}, mode=0)

            else:
                print("Action:",action,"not supported")

    # Format and output the results from all the previous actions !
    if (args.resultsdir is None):
        fd = sys.stdout
    else:
        filename = "-".join(args.actions).replace(",","_")
        fd = open(args.resultsdir+"/"+filename+".csv","w")

    try:
        # The output format is chosen from the type of the first response, or from the name of the last action
        if (len(action_responses)>0  and isinstance(action_responses[0],Counter)):
            utils.print_counter(sum(action_responses,Counter()),fd)
        elif (args.actions[-1].split(",")[0] == "size"):
            fd.write("File name,File Size,Savings\n")
            # Sort by the saving percentage (with the trailing "%" stripped), largest first
            for response in sorted(action_responses, key = lambda x: float(x[2][0:-1]), reverse = True):
                fd.write(",".join(response)+"\n")
        elif (args.actions[-1].split(",")[0] == "coverage"):
            # Sum the first and second field of every response position by position
            for bits,tresponse in enumerate(zip(*action_responses)):
                fd.write(str(bits)+","+str(sum([x[0] for x in tresponse]))+","+str(sum([x[1] for x in tresponse]))+"\n")
        elif (len(action_responses) > 0 and isinstance(action_responses[0],str)):
            for response in action_responses:
                fd.write(response+"\n")
    finally:
        # Close the results file even if formatting fails; never close stdout
        if (args.resultsdir is not None):
            fd.close()

    if (verbosity>0): print("Total Execution time was", time()-start_time)
    return errors



if __name__ == "__main__":
    # Propagate the error count returned by main() as the process exit status
    # instead of discarding it.
    sys.exit(main())
